# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 23:11:15 2023

@author: 29672366
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from utils.DigitEmbedding import DigitEmbedding

# Smoke-test script: push a few integers through DigitEmbedding.encode and
# feed the result back into DigitEmbedding.decode, printing both stages.
num_digits = 8
# NOTE(review): the meaning of the second positional argument (3) is defined
# by DigitEmbedding and is not visible from this script -- confirm there.
digit_embedding = DigitEmbedding(num_digits, 3)

input_seq = [19229, 395, 8271]


def _encode_then_decode(embedding, values):
    """Encode *values*, print the result, decode it, and print that too.

    Returns the ``(encoded, decoded)`` pair so the caller can keep both.
    """
    packed = embedding.encode(values)
    print(packed)
    unpacked = embedding.decode(packed)
    print(unpacked)
    return packed, unpacked


encoded, decoded = _encode_then_decode(digit_embedding, input_seq)
